# ABOUTME: Generates ResidualMomentum following Blitz, Huij and Martens 2011, Table 2B 1M
# ABOUTME: computes ResidualMomentum6m (placebo) and ResidualMomentum (predictor) from FF3 residuals of 36-observation rolling regressions

"""
ZZ1_ResidualMomentum6m_ResidualMomentum.py

Usage:
    Run from [Repo-Root]/Signals/pyCode/
    python3 Predictors/ZZ1_ResidualMomentum6m_ResidualMomentum.py

Inputs:
    - monthlyCRSP.parquet: Monthly CRSP data with columns [permno, time_avail_m, ret]
    - monthlyFF.parquet: Fama-French factors with columns [time_avail_m, rf, mktrf, smb, hml]

Outputs:
    - ResidualMomentum6m.csv: CSV file with columns [permno, yyyymm, ResidualMomentum6m]
    - ResidualMomentum.csv: CSV file with columns [permno, yyyymm, ResidualMomentum]
"""

import polars as pl
import polars_ols as pls  # Registers .least_squares namespace
import polars.selectors as cs
import numpy as np
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from utils.save_standardized import save_predictor, save_placebo


# Start-of-run banner; one print of newline-joined lines emits the same text
# as four separate print calls.
print(
    "\n".join(
        (
            "=" * 80,
            "🏗️  ZZ1_ResidualMomentum6m_ResidualMomentum.py",
            "Generating ResidualMomentum6m and ResidualMomentum predictors",
            "=" * 80,
        )
    )
)

# DATA LOAD
print("📊 Loading monthly CRSP and Fama-French data...")

# Monthly CRSP returns, trimmed to the identifier, availability month and return.
print("Loading monthlyCRSP.parquet...")
crsp = pl.read_parquet("../pyData/Intermediate/monthlyCRSP.parquet")
crsp = crsp.select(["permno", "time_avail_m", "ret"])
print(f"Loaded CRSP: {len(crsp):,} monthly observations")

# Monthly Fama-French factors plus the risk-free rate.
print("Loading monthlyFF.parquet...")
ff = pl.read_parquet("../pyData/Intermediate/monthlyFF.parquet")
ff = ff.select(["time_avail_m", "rf", "mktrf", "hml", "smb"])
print(f"Loaded FF factors: {len(ff):,} monthly observations")

# Inner join on month: only months present in both sources survive
# (the analogue of Stata's merge m:1 ... keep(match)).
print("Merging CRSP and FF data...")
df = crsp.join(ff, on="time_avail_m", how="inner")
print(f"Merged dataset: {len(df):,} observations")

# SIGNAL CONSTRUCTION
print("\n🔧 Starting signal construction...")

# Excess return retrf = ret - rf, then order each permno's rows in time so
# that the position index and all later windowed operations are chronological.
print("Calculating excess returns (retrf = ret - rf)...")
df = df.with_columns(retrf=pl.col("ret") - pl.col("rf")).sort(
    ["permno", "time_avail_m"]
)

# time_temp: 0-based row position within each permno (Stata's `_n`-style
# index), so rolling windows count observations rather than calendar months.
print("Creating time_temp position index by permno...")
df = df.with_columns(time_temp=pl.int_range(pl.len()).over("permno"))


print(
    "Running rolling 36-observation FF3 regressions by permno using direct polars-ols helper..."
)
print("Processing", df["permno"].n_unique(), "unique permnos...")

# Rolling FF3 regression of retrf on (mktrf, hml, smb) plus an intercept,
# run separately per permno via polars-ols, keeping each row's residual.
# Windows are 36 rows with a 36-row minimum; rows are ordered by the
# position index time_temp, intended to mirror Stata's
# `asreg ..., window(time_temp 36) min(36)`.
# NOTE(review): null_policy="drop" is polars-ols' null handling — confirm it
# matches asreg's treatment of missing retrf inside a window.
df = df.sort(["permno", "time_temp"])

_ff3_regressors = [pl.col(factor) for factor in ("mktrf", "hml", "smb")]
_residual_expr = (
    pl.col("retrf")
    .least_squares.rolling_ols(
        *_ff3_regressors,
        window_size=36,
        min_periods=36,
        mode="residuals",
        add_intercept=True,
        null_policy="drop",
    )
    .over("permno")
)
df = df.with_columns(_residual_expr.alias("_residuals"))

print(f"Completed rolling regressions for {len(df):,} observations")


# Lag the FF3 residuals by one observation within permno, take rolling
# mean/sd of the lagged residuals over 6 and 11 observations, and form the
# residual-momentum signals as mean / sd.
print("Calculating lagged residuals and momentum signals...")
df = df.with_columns(
    # temp = l1._residuals (previous row's residual within the same permno)
    pl.col("_residuals").shift(1).over("permno").alias("temp")
)

print("Calculating 6-observation and 11-observation rolling momentum signals...")

# (output signal name, window length in observations)
_MOMENTUM_WINDOWS = (("ResidualMomentum6m", 6), ("ResidualMomentum", 11))

# Position-based rolling statistics; a full window of non-null values is required.
# Produces mean6_temp, sd6_temp, mean11_temp, sd11_temp (in that order).
_rolling_stats = []
for _signal_name, _window in _MOMENTUM_WINDOWS:
    _rolling_stats.append(
        pl.col("temp")
        .rolling_mean(window_size=_window, min_samples=_window)
        .over("permno")
        .alias(f"mean{_window}_temp")
    )
    _rolling_stats.append(
        pl.col("temp")
        .rolling_std(window_size=_window, min_samples=_window, ddof=1)
        .over("permno")
        .alias(f"sd{_window}_temp")
    )
df = df.with_columns(_rolling_stats)

# Signal = mean / sd. A zero sd would give inf (or NaN for 0/0) in polars,
# whereas Stata's division by zero yields missing; when() without otherwise()
# returns null so degenerate windows are missing instead of infinite.
df = df.with_columns(
    [
        pl.when(pl.col(f"sd{_window}_temp") != 0)
        .then(pl.col(f"mean{_window}_temp") / pl.col(f"sd{_window}_temp"))
        .alias(_signal_name)
        for _signal_name, _window in _MOMENTUM_WINDOWS
    ]
)


# Display signal summary statistics
print("\n📈 Signal summary statistics:")


def _format_stat(value):
    """Return *value* formatted with 4 decimals, or "NA" when it is missing.

    ``df.select(pl.col(c).mean()).item()`` (and ``std()``) yields None when the
    column has no non-null values; formatting None with ``:.4f`` would raise
    TypeError and abort the script before the signals are saved.
    """
    return "NA" if value is None else f"{value:.4f}"


for _signal in ("ResidualMomentum6m", "ResidualMomentum"):
    _mean = df.select(pl.col(_signal).mean()).item()
    _std = df.select(pl.col(_signal).std()).item()
    print(f"{_signal} - Mean: {_format_stat(_mean)}, Std: {_format_stat(_std)}")

for _signal in ("ResidualMomentum6m", "ResidualMomentum"):
    _non_missing = df.select(pl.col(_signal).is_not_null().sum()).item()
    print(f"Non-missing {_signal}: {_non_missing:,}")


# SAVE
print("\n💾 Saving signals...")

# The save helpers take pandas DataFrames, so convert the needed columns once.
final_df = df.select(
    ["permno", "time_avail_m", "ResidualMomentum6m", "ResidualMomentum"]
).to_pandas()

# ResidualMomentum6m is written as a placebo, ResidualMomentum as a predictor.
for signal_name, save_fn in (
    ("ResidualMomentum6m", save_placebo),
    ("ResidualMomentum", save_predictor),
):
    save_fn(final_df[["permno", "time_avail_m", signal_name]], signal_name)

# Closing banner; one newline-joined print emits the same text as six prints.
print(
    "\n".join(
        [
            "\n" + "=" * 80,
            "✅ ZZ1_ResidualMomentum6m_ResidualMomentum.py completed successfully",
            "Generated 2 signals:",
            "  • ResidualMomentum6m: 6 month residual momentum (Placebo)",
            "  • ResidualMomentum: Momentum based on FF3 residuals (Predictor)",
            "=" * 80,
        ]
    )
)
